
[None][feat] dual-pool KV cache with SWA block eviction for gemma4 #12813

Open
suyoggupta wants to merge 16 commits into NVIDIA:main from nv-auto-deploy:sg/swa

Conversation

@suyoggupta
Collaborator

@suyoggupta suyoggupta commented Apr 7, 2026

Summary

  • Adds dual-pool KV cache architecture for models with mixed attention head dimensions (e.g., gemma4-26B: head_dim=256 sliding + head_dim=512 full attention)
  • Each head_dim group gets its own KVCacheManager pool with independent max_attention_window, enabling SWA block eviction during decode
  • MultiPoolKVCacheManager wrapper provides unified API for lifecycle, scheduling, and block retrieval
  • kv_page_offset support in triton write and context kernels for correct windowed cache_loc indexing
  • C++ get_num_front_blocks_removed binding for SWA eviction tracking
  • MMLU 75.6% and GSM8k 91.4% — matching baseline accuracy
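The core idea in the first two bullets, giving each head_dim group its own KV cache pool, can be sketched as a grouping step. This is an illustrative sketch, not the PR's actual code: `group_layers_by_head_dim` and the layer layout are hypothetical, and the real MultiPoolKVCacheManager wraps full KVCacheManager instances per group.

```python
from collections import defaultdict


def group_layers_by_head_dim(layer_head_dims):
    """Group layer indices by head_dim so each group can be served by
    its own KV cache pool with an independent attention window
    (hypothetical helper; names are illustrative)."""
    groups = defaultdict(list)
    for layer_idx, head_dim in enumerate(layer_head_dims):
        groups[head_dim].append(layer_idx)
    return dict(groups)


# A gemma4-26B-like layout: sliding layers use head_dim=256,
# full-attention layers use head_dim=512.
layout = [256, 256, 512, 256, 256, 512]
pools = group_layers_by_head_dim(layout)
# pools == {256: [0, 1, 3, 4], 512: [2, 5]}
```

Each resulting group would then get its own pool and `max_attention_window`, which is what allows SWA blocks in the sliding group to be evicted while the full-attention group keeps its entire history.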

Test plan

  • E2E build_and_run_ad.py with gemma4-26B-A4B-it chat template — coherent output
  • MMLU accuracy test (TestGemma4MoE::test_bf16) — 75.6% matching baseline 75.4%
  • GSM8k accuracy — 91.4% matching baseline 91.1%
  • SWA eviction verified via log: SWA eviction: group=0 window=1024 ... evicted=N
  • 7 new unit tests for dual-pool architecture (all passing)
  • Pre-existing unit tests unaffected

Stacked on #12710
🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for Gemma 4 MoE and Gemma 3n models with AutoDeploy integration and documentation.
    • Added sliding-window attention support for memory-efficient inference.
    • Added shared KV cache optimization for model layers.
    • Extended multi-pool KV cache manager for variable sequence window attention.
  • Documentation

    • Added Gemma 4 TensorRT-LLM cookbook with end-to-end AutoDeploy workflow.
  • Tests

    • Added comprehensive test coverage for Gemma models, shared KV attention, and sliding-window behaviors.

bmarimuthu-nv and others added 15 commits April 6, 2026 15:59
…IA#12205)

Adds Gemma3n custom model with shared KV attention, sliding window attention,
and related attention backend changes for AutoDeploy.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Adds Gemma4 (MoE) custom model for AutoDeploy with:
- Custom modeling code supporting K=V attention, proportional RoPE,
  parallel dense+MoE, per-layer scalars, and logit softcapping
- Gelu activation support in torch_moe for Gemma4 MoE layers
- Hierarchical equivalence tests
- Model registry config (triton_paged attention backend for head_dim=512)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…nd tests

- Remove incorrect +1.0 scale_shift from Gemma4RMSNorm. HF transformers
  5.5.0 stores effective norm weights directly in the checkpoint; the
  previous implementation incorrectly added 1.0 at load time, causing
  compounding numerical drift across layers and garbled generation.
- Add google/gemma-4-26B-A4B base model registry entry with
  gemma4_moe_base.yaml config.
- Strengthen test_full_model_equivalence with end-to-end logits
  comparison against standalone reference model.
- Add export functional equivalence assertion (pre-export vs post-export).
- Update reference _RefRMSNorm to match corrected norm semantics.
- Update MoE block test to manually unfuse weights (hook now on decoder
  layer, not MoE block).

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
…ked prefill

Add piecewise CUDA graph compilation, expanded batch sizes, chunked
prefill, and KV cache config to both gemma4_moe.yaml and
gemma4_moe_base.yaml.

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Add sliding window attention to both decode (FlashDecoding) and
context/prefill kernels. When sliding_window is set, queries only
attend to the most recent W KV tokens, enabling efficient long-context
inference for models with sliding window attention (e.g. Mistral).
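The windowing rule described above ("queries only attend to the most recent W KV tokens") reduces to a per-position mask predicate. The following is a minimal sketch of that predicate in plain Python, assuming absolute token positions; the actual Triton kernels apply the same condition vectorized over pages.

```python
def swa_mask(q_pos, k_pos, window):
    """True if the query at absolute position q_pos may attend to the
    key at absolute position k_pos under causal sliding-window
    attention: the key must not be in the future, and must lie within
    the most recent `window` positions (the query position itself
    counts as one of them)."""
    return k_pos <= q_pos and k_pos > q_pos - window


# window=3: the query at position 5 sees keys 3, 4, 5 only
visible = [k for k in range(8) if swa_mask(5, k, 3)]
# visible == [3, 4, 5]
```

In the decode kernel this predicate also bounds which pages need to be visited at all, which is why the split-K heuristic can use the effective (windowed) sequence length rather than the full cache length.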

Key changes:
- Decode kernel: restrict page splits to window range, apply per-token
  window mask, use effective sequence length for split-K heuristic
- Context kernel: skip pages before window in Phase 1, add per-query
  sliding window mask in both Phase 1 (full pages) and Phase 2
  (partial/causal pages), guard against NaN from -inf exponents
- triton_paged_mha_with_cache: thread sliding_window through to both
  kernels, add optional pre-allocated output buffer support
- Disable SDPA fast-path when sliding window is active
- Extract sliding_window constant from source attention node

MMLU: 75, GSM8K: 90

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
…fixes, gather logits softcap, sliding window tests

- Enable MLIR elementwise fusion, gather_logits, and fuse_gemms transforms
  in gemma4_moe config; switch gemma4 models to world_size_1
- Register triton_paged ops in piecewise_utils for CUDA graph capture
- Add torch.cuda.synchronize after piecewise graph replay to prevent
  race conditions with non-default streams (e.g. fused_moe)
- Fix MLIR triton emitter: use tl.extra.cuda.libdevice for math ops
  (gelu, tanh, exp, softplus, pow); handle scalar/rank-0 tensor inputs;
  add AD_DUMP_KERNELS_DIR env var for kernel source inspection
- Fix gather_logits_before_lm_head to walk backward through post-lm_head
  ops (div, tanh, mul softcapping) to find the actual linear node
- Add sliding window attention tests for decode and context kernels
- Add softcapping LM head test for gather logits transform

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…t Geglu in NVFP4 MoE

Pass window_left to fast_decode_plan in plan_generate_only so sliding
window attention is respected during CUDA-graph-captured decode. Add
early rejection of Gelu/Geglu in NVFP4 TRTLLM-Gen MoE since the
underlying kernel does not support it.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Migrate cuda_graph_batch_sizes to cuda_graph_config.batch_sizes and
add explicit max_batch_size to gemma3n config to preserve prior default.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai bot commented Apr 7, 2026

📝 Walkthrough

Introduces Gemma 3n and Gemma 4 model support with custom PyTorch implementations, configuration files, and deployment cookbook. Adds shared KV cache and sliding-window attention infrastructure across all attention backends, multi-pool KV cache management, and updates compiler infrastructure for piecewise compilation and metadata handling.

Changes

Cohort / File(s) Summary
Gemma Model Implementations
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py, tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
New custom PyTorch model implementations for Gemma 3n and Gemma 4 with RoPE embeddings, attention with shared-KV and sliding-window support, MoE routing/execution, and multimodal wrapper classes. Includes autoload hooks for weight unfusing and checkpoint adaptation.
Gemma Configuration & Registry
examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml, examples/auto_deploy/model_registry/configs/gemma4_moe.yaml, examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml, examples/auto_deploy/model_registry/models.yaml, tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
YAML deployment configs for Gemma 3n and Gemma 4 with attention/compile backends, KV cache settings, and transform toggles. Updates model registry and custom model exports.
Attention Backend - FlashInfer
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
Added window_left parameter to planning, sliding-window conversion logic, read_cache_only flag, and supports_shared_kv() method. Extended get_constants signature to accept optional cache_config.
Attention Backend - Torch
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
Added layer_idx and shared_kv_source_layer_idx metadata parameters. Implemented read-cache-only attention paths with _write_generate_kv_cache helper. Extended get_constants to accept cache_config and added supports_shared_kv() method.
Attention Backend - Triton Paged
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
Added windowed (VSWA) support with kv_page_offset input and per-sequence page offset tracking. Extended kernels with sliding_window parameter, Phase 1/2 sliding-window masking, and out-buffer handling. Updated dispatch signatures and metadata constants.
Attention Backend - Triton & TRT-LLM
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
Updated get_constants signature to accept optional cache_config. Refined layout extraction via extract_op_args helper. TRT-LLM now threads sink_token_length from cache config.
Attention Interface & Base Classes
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Added get_layer_idx() and get_shared_kv_source_layer_idx() classmethods for metadata extraction. Extended get_constants signature with cache_config parameter. Introduced supports_shared_kv() default implementation. Added window-group state tracking and per-group cache metadata tensor registration in SequenceInfo.
FLA Backends
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py, tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py, tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py
Signature updates to get_constants methods to accept optional cache_config parameter for consistency across attention backends.
Mamba & MLA Backends
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_*.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mla/*
Extended get_constants signatures across Mamba and MLA descriptor classes to accept optional cache_config parameter for unified framework consistency.
MoE & Kernels
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py, tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
Added GELU/GEGLU activation support. Introduced _normalize_trtllm_act_fn helper for consistent activation normalization in fused MoE kernels.
KV Cache Transform & Compiler
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py, tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py, tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py, tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
Extended KV cache transform to support sliding-window, shared-KV aliasing, window groups, and per-group metadata placeholders. Updated graph insertion to thread cached_attn_op per node. Added synchronization in piecewise graph capture. Extended dynamic op registries.
KV Cache Manager - Multi-pool
tensorrt_llm/_torch/auto_deploy/shim/interface.py, tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py, tensorrt_llm/_torch/pyexecutor/resource_manager.py
Introduced MultiPoolKVCacheManager wrapper for managing multiple KV cache pools with different head_dim/dtype/layout groups. Updated executor to handle per-group cache indices, window offsets, and per-group metadata tensor staging. Added get_num_front_blocks_removed() method. Made model_config optional in VSWA helpers.
MLIR & Utilities
tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py, tensorrt_llm/_torch/auto_deploy/utils/node_utils.py, tensorrt_llm/_torch/auto_deploy/utils/_graph.py, tensorrt_llm/_torch/auto_deploy/export/export.py, tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
Updated Triton emission to use libdevice functions for math ops, refined input loading with scalar flags. Introduced get_op_schema() public utility for unified op schema resolution. Updated LM head gather transform to traverse through post-processing operations.
C++ Bindings
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Added Python binding for get_num_front_blocks_removed method on BaseKVCacheManager.
Documentation & Deployment
docs/source/models/supported-models.md, examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
Added Gemma 4 to supported models list. Created end-to-end Jupyter cookbook demonstrating AutoDeploy workflow with Gemma 4 MoE, including server startup and OpenAI-compatible client usage.
Test Coverage
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py, tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py, tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py, tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py, tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py, tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py, tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py, tests/integration/defs/accuracy/test_llm_api_autodeploy.py
Comprehensive test coverage for Gemma 3n/4 modeling equivalence, shared-KV cache behavior, sliding-window attention, multi-pool KV cache management, VSWA configuration, and end-to-end accuracy evaluation. Tests validate graph transformations, metadata wiring, and functional correctness.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~85 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 33.84%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title clearly describes the main feature, a dual-pool KV cache architecture with SWA block eviction support for Gemma4, which is the primary change across the changeset.
  • Description check ✅ Passed: The PR description clearly explains the changes, provides specific test results, and addresses the problem being solved.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (1)

432-456: ⚠️ Potential issue | 🔴 Critical

Fix the mutates_args declaration to match the actual tensor mutations.

torch_backend_mha_with_cache() calls _write_generate_kv_cache() and _update_kv_cache() (lines 498–500) which directly modify k_cache and v_cache via indexed assignment. However, the decorator declares mutates_args=(), creating a contract mismatch. This will cause torch.compile to misoptimize the cached attention computation.

🩹 Proposed fix
-@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=())
+@torch.library.custom_op(
+    "auto_deploy::torch_cached_attention_with_cache",
+    mutates_args=("k_cache", "v_cache"),
+)
 def torch_backend_mha_with_cache(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`
around lines 432 - 456, The decorator for torch_backend_mha_with_cache
incorrectly declares mutates_args=() while the function mutates k_cache and
v_cache via _write_generate_kv_cache and _update_kv_cache; update the
`@torch.library.custom_op` on torch_backend_mha_with_cache to list the mutated
tensor arguments (k_cache and v_cache) in mutates_args so the op contract
matches the actual in-place updates and prevents torch.compile misoptimizations.
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)

977-992: ⚠️ Potential issue | 🟡 Minor

Use the union of groups in the returned KV stats.

In multi-pool mode, kv_managed here is only the last group processed by the loop above, so total_managed and the returned kv_managed count under-report earlier pools. The logging becomes misleading precisely when dual-pool mode is enabled; this should use kv_managed_all.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 977 - 992,
The returned KV stats and total_managed use only the last group's kv_managed;
update the aggregation to use the union across all groups (kv_managed_all)
instead. Replace usages of kv_managed when computing total_managed and the
"kv_managed" return value with kv_managed_all (i.e., compute total_managed =
len(kv_managed_all) + ssm_managed_count + conv_managed_count and return
"kv_managed": len(kv_managed_all)); keep other derived counts (paged_total,
kv_total, paged_other, other_total) unchanged. Use the existing symbols
kv_managed_all, ssm_managed_count, conv_managed_count, total_managed, and the
return dict in this function to locate where to apply the change.
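The under-reporting pattern flagged here, a loop variable leaking out of a loop and being mistaken for an aggregate, can be illustrated with a small sketch. Names like `aggregate_managed` are hypothetical; the point is simply that stats across pools must union the per-pool sets rather than take the last one.

```python
def aggregate_managed(per_pool_managed):
    """Union the per-pool managed-layer sets so reported stats cover
    every pool, not just the last one iterated (illustrative sketch
    of the fix the review describes)."""
    kv_managed_all = set()
    for managed in per_pool_managed:
        kv_managed_all |= set(managed)
    return kv_managed_all


# two pools managing disjoint layer sets
pools = [[0, 1, 3], [2, 5]]
# taking only the last pool would report 2 layers; the union reports 5
assert len(aggregate_managed(pools)) == 5
```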
🧹 Nitpick comments (7)
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py (1)

135-137: Consider centralizing FlashInfer op invocation to reduce signature-drift risk.

These call-site updates are correct, but the repeated long positional argument list is brittle. A small local helper would make future arity/order changes safer.

Refactor sketch
+def _call_flashinfer_mha_with_cache(
+    q, k, v,
+    batch_info_host, qo_indptr_host, paged_kv_indptr, paged_kv_indptr_host,
+    paged_kv_indices, paged_kv_last_page_len, paged_kv_last_page_len_host,
+    seq_len_with_cache_host, batch_indices, positions, kv_cache,
+    k_scale, v_scale,
+):
+    return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
+        q, k, v,
+        batch_info_host, qo_indptr_host, paged_kv_indptr, paged_kv_indptr_host,
+        paged_kv_indices, paged_kv_last_page_len, paged_kv_last_page_len_host,
+        seq_len_with_cache_host, batch_indices, positions, kv_cache,
+        None, None, k_scale, v_scale,
+    )

Also applies to: 265-267, 396-398, 491-493, 625-627, 784-786, 892-894, 987-989

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py`
around lines 135 - 137, This test file repeats a long positional argument list
when invoking the FlashInfer attention op, which is brittle; create a small
local helper (e.g., call_flashinfer_attention or flashinfer_attention_helper) in
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
that wraps the actual op invocation and accepts either a partial kwargs dict or
the same parameters with sensible defaults, then replace each repeated call-site
(the clusters around lines shown in the comment) to call that helper instead;
update occurrences referenced in the review (around the groups at 135-137,
265-267, 396-398, 491-493, 625-627, 784-786, 892-894, 987-989) so future
arity/order changes only need updating in the single helper.
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py (1)

493-496: Remove redundant q_pos assignment in the reference helper.

Line 493 is overwritten at Line 495, so it is dead code and can be dropped for clarity.

♻️ Proposed cleanup
-        q_pos = torch.arange(s_k - s_q + s_q, device=q.device)  # absolute positions
-        # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1]
+        # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1]
         q_pos = torch.arange(s_k - s_q, s_k, device=q.device)  # [s_q]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py`
around lines 493 - 496, The first assignment to q_pos (q_pos = torch.arange(s_k
- s_q + s_q, device=q.device)) is dead code because it is immediately
overwritten by the later assignment; remove that redundant line and retain the
intended prefill assignment (q_pos = torch.arange(s_k - s_q, s_k,
device=q.device)) and k_pos assignment (k_pos = torch.arange(s_k,
device=q.device)); also update or keep the inline comment to reflect that q_pos
now represents absolute positions for prefill.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py (2)

446-449: Use pytest.skip instead of early return for CUDA-gated tests.

Early return exits silently without recording a skip. Using pytest.skip provides visibility in test reports and is the standard pattern for conditionally skipping tests.

♻️ Suggested fix
 def test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing():
     if not torch.cuda.is_available():
-        return
+        pytest.skip("CUDA not available")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py`
around lines 446 - 449, The test function
test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing
uses an early return when CUDA is unavailable; replace that return with a call
to pytest.skip("CUDA is not available") so the test is recorded as skipped, and
ensure pytest is imported at the top of the test module if not already present.

96-126: Unused variable batch should be prefixed with underscore.

Static analysis (RUF059) flags batch as unused. While this is a common pattern when unpacking tensor shapes, adding an underscore prefix silences the warning and signals intent.

♻️ Suggested fix
 def _manual_attention(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     sliding_window: int | None = None,
 ) -> torch.Tensor:
-    batch, seq_len_q, num_heads, _ = q.shape
+    _batch, seq_len_q, num_heads, _ = q.shape
     _, seq_len_k, num_kv_heads, _ = k.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py`
around lines 96 - 126, The variable `batch` in the `_manual_attention` function
is unused and triggers a static analysis warning; change the unpacking from
`batch, seq_len_q, num_heads, _ = q.shape` to prefix the unused variable (e.g.,
`_batch, seq_len_q, num_heads, _ = q.shape`) or otherwise rename it to `_batch`
to silence RUF059 and indicate it is intentionally unused; update any references
if you choose a different name.
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

857-862: Consider using debug-level logging for SWA eviction details.

This log is emitted once per batch (first sequence only), but during high-throughput inference with many batches, INFO-level logs can still be noisy. Consider ad_logger.debug for routine operational details, reserving INFO for significant state changes.

♻️ Suggested change
                     if front_removed > 0 and i == 0:  # log once per batch, first seq only
-                        ad_logger.info(
+                        ad_logger.debug(
                             f"SWA eviction: group={group_idx} window={window_size} "
                             f"req={request.py_request_id} total_blocks={len(all_indices)} "
                             f"evicted={front_removed} active={num_active} offset={page_offset_g}"
                         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py` around lines 857 - 862,
The SWA eviction message currently uses ad_logger.info and can be noisy; change
the log level to debug in the SWA eviction block inside ad_executor.py (the
conditional that checks front_removed > 0 and i == 0) so routine eviction
details are emitted with ad_logger.debug instead of ad_logger.info, keeping the
same message text and context (group_idx, window_size, request.py_request_id,
len(all_indices), front_removed, num_active, page_offset_g) to preserve
diagnostic data.
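The `evicted=N` value in the SWA eviction log above comes down to simple page arithmetic: once the cached sequence outgrows the window, the oldest blocks that fall entirely outside it can be released. The sketch below is illustrative arithmetic only; the real accounting lives in the C++ `get_num_front_blocks_removed` binding and may differ in edge-case handling.

```python
def front_blocks_evictable(seq_len, window, tokens_per_block):
    """Number of leading (oldest) KV blocks that lie entirely outside
    a sliding window of `window` tokens (illustrative; the actual
    get_num_front_blocks_removed implementation may differ)."""
    if seq_len <= window:
        return 0
    oldest_needed = seq_len - window          # first token still in the window
    return oldest_needed // tokens_per_block  # whole blocks before that token


# 2050 cached tokens, window=1024, 32-token pages:
# tokens 0..1025 are outside the window, so 32 whole blocks are evictable
assert front_blocks_evictable(2050, 1024, 32) == 32
```

The per-sequence `kv_page_offset` then shifts `cache_loc` indexing past those evicted front blocks, which is why the triton write and context kernels needed the new offset parameter.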
tests/integration/defs/accuracy/test_llm_api_autodeploy.py (1)

982-984: Mutable class attribute should use a class property or be typed as ClassVar.

Static analysis flags EXTRA_EVALUATOR_KWARGS as a mutable default value for a class attribute (RUF012). While this dict is not mutated in practice, it's a minor code smell. Consider using ClassVar annotation or a property.

♻️ Optional fix using ClassVar annotation
+from typing import ClassVar
+
 class TestGemma4MoE(LlmapiAccuracyTestHarness):
     """Bench-run coverage for Gemma4 MoE via AutoDeploy."""

     MODEL_NAME = "google/gemma-4-26B-A4B-it"
-    EXTRA_EVALUATOR_KWARGS = {
+    EXTRA_EVALUATOR_KWARGS: ClassVar[dict[str, bool]] = {
         "apply_chat_template": True,
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py` around lines 982
- 984, EXTRA_EVALUATOR_KWARGS is defined as a mutable class attribute; change it
to a ClassVar-annotated constant or convert it into a property to satisfy static
analysis: annotate the symbol EXTRA_EVALUATOR_KWARGS as ClassVar[dict[str, Any]]
(import ClassVar and Any from typing) or replace it with a `@property` that
returns a fresh dict (e.g., def EXTRA_EVALUATOR_KWARGS(self) -> dict[str, Any]:
return {"apply_chat_template": True}); update the declaration and imports
accordingly.
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)

47-165: Type the new wrapper surface before more callers depend on it.

MultiPoolKVCacheManager is a new public API, but most of its methods/properties are unannotated. That makes it harder to use as a drop-in KVCacheManager replacement in type-checked code and obscures which methods intentionally diverge from single-pool behavior.

As per coding guidelines, "Always annotate functions with type hints" and "Externally called functions must have docstrings; function arguments should be documented, especially for class initializers".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 47 - 165,
MultiPoolKVCacheManager is missing type hints and public docstrings; update the
class and its public surface to be a proper drop-in KVCacheManager replacement
by adding PEP484 type annotations and short docstrings: annotate __init__(self,
managers: List[KVCacheManager], primary_idx: int = 0) -> None and document
parameters, add return types for properties (e.g., impl -> Any or the actual
Impl type, tokens_per_block -> int, max_blocks_per_seq -> int,
blocks_in_primary_pool -> int), methods (get_num_free_blocks() -> int,
get_max_resource_count() -> int, get_needed_resource_to_completion(request:
RequestType) -> int or appropriate type, get_num_kv_blocks(num_tokens: int) ->
int, prepare_resources(scheduled_batch: ScheduledBatchType) -> None,
free_resources(request: RequestType, pin_on_release: bool = False) -> None,
update_resources(...)-> None, add_dummy_requests(request_ids: Sequence[str],
**kwargs) -> Any, shutdown() -> None, get_pool(group_idx: int) ->
KVCacheManager, num_pools -> int, max_concurrent_sequences -> int,
get_buffers(idx: int, kv_layout: str = "NHD") -> BufferType (or raise
NotImplementedError with a docstring explaining alternative),
event_buffer_max_size -> int, enable_block_reuse -> bool, enable_partial_reuse
-> bool, is_draft -> bool, kv_cache_pool_pointers -> PointerType,
kv_cache_pool_mapping -> MappingType, get_cache_indices(request: RequestType,
**kwargs) -> IndexType, store_blocks_for_reuse(request: RequestType, pin_blocks:
bool = False) -> None; include short docstrings on the class and each public
method/prop (at least __init__, get_buffers, get_pool, and get_cache_indices)
describing behavior and any divergence from single-pool KVCacheManager so
callers and type checkers can rely on it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb`:
- Line 160: The BASE_URL constant is set to a non-routable server bind address
("0.0.0.0"); change BASE_URL to use a routable client endpoint such as
"http://127.0.0.1:8000/v1" or "http://localhost:8000/v1" so client requests
target the local server correctly — locate the BASE_URL assignment in the
gemma_4_trtllm_cookbook notebook (the line containing BASE_URL =
"http://0.0.0.0:8000/v1") and replace the host portion accordingly.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py`:
- Around line 903-905: When you grow the base cache tensor, also grow the
per-group cache tensors so their capacities stay in sync: in the branch that
calls self._input_buffer.resize("cache_loc", estimated_capacity) update every
cache_loc_g* buffer (the attributes created/used by register_window_groups,
e.g., cache_loc_g0, cache_loc_g1, etc.) to the same estimated_capacity (and do
the same mirrored update in the other resize block around lines 919-924). Locate
the places that call self._input_buffer.resize("cache_loc", ...) and add a loop
or explicit resizes to update each cache_loc_g* and any cache_loc_per_group
bookkeeping so staging (cache_loc_per_group) uses the new capacity.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`:
- Around line 257-261: The decode-only path drops the sliding-window
(window_left) info, so wire it through: add a sliding_window/ window_left
parameter to prepare_flashinfer_metadata_host() and pass it into
plan_generate_only(), and then forward that value into
flashinfer.decode.fast_decode_plan() (or alternatively detect sliding_window and
route pure-decode batches to plan_decode()); update _to_flashinfer_window_left()
usage to compute the inclusive window_left and ensure
plan_generate_only()/fast_decode_plan() receive that value so SWA decode is
planned correctly.
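For context on the `window_left` wiring discussed here: flashinfer expresses the window as the number of attendable tokens strictly to the left of the query, with -1 meaning no window. One plausible conversion from an inclusive sliding-window size is sketched below; the exact off-by-one convention of the real `_to_flashinfer_window_left()` helper is an assumption.

```python
def to_window_left(sliding_window):
    """Convert an inclusive sliding-window size (the query plus W-1
    prior tokens) to a flashinfer-style window_left, where -1 disables
    windowing. Illustrative only; mirrors what a helper like
    _to_flashinfer_window_left() might do, convention assumed."""
    if sliding_window is None or sliding_window <= 0:
        return -1
    return sliding_window - 1


assert to_window_left(None) == -1   # full attention
assert to_window_left(1024) == 1023
```

Whatever the exact convention, the review's point stands: the value must reach the decode-only planning path, or CUDA-graph-captured decode silently runs with full attention.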

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1151-1163: The current code overwrites seq_len_with_cache with a
window-local length, causing masking to mix global and local coordinates
(kv_page_offset, q_positions_2d, first_q_kv_pos vs kv_base_pos) and leading to
stale/skipped tokens after eviction; instead, preserve the absolute
seq_len_with_cache for masking and only derive a separate local length for
page-iteration bounds: compute cache_len_capped_local =
torch.minimum(cache_len_raw, max_cached) and seq_len_with_cache_local =
cache_len_capped_local + q_lens, keep seq_len_with_cache unchanged, and pass/use
seq_len_with_cache_local solely where page/local bounds are required by
_paged_context_kernel or page-iteration logic (leave masking code using
seq_len_with_cache).

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Line 337: The constant list in TrtllmAttention.get_constants() now includes
sink_token_length as the sixth positional constant, but the cached-op signatures
incorrectly place the output buffer parameter (out) between out_scale and
sink_token_length; update the cached-op call/signature(s) so that
sink_token_length remains in the constants prefix (i.e., keep sink_token_length
as the sixth positional constant returned by TrtllmAttention.get_constants())
and move the out parameter after the constants in the cached-op invocation;
apply the same fix to the other cached-op occurrences referenced around the
blocks at the other locations (the similar cached-op signatures near the later
ranges).

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py`:
- Around line 856-912: The code assumes managers is non-empty and will raise
IndexError for empty kv_groups; before constructing MultiPoolKVCacheManager or
indexing managers[primary_idx], check if managers is empty and handle the no-KV
fallback: if no managers, set self._kv_cache_manager to the appropriate fallback
(e.g., instantiate MambaHybridCacheManager or allocate local resources) instead
of creating MultiPoolKVCacheManager, ensuring the same fallback is used where
managers[primary_idx] would be accessed; update the branch that currently
assigns self._kv_cache_manager (the block that chooses between single manager
and MultiPoolKVCacheManager) to first handle len(managers) == 0, then
len(managers) == 1, then the multi-case, and reference the symbols
_kv_cache_manager, managers, primary_idx, MultiPoolKVCacheManager, and
MambaHybridCacheManager.

In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 1052-1058: get_op_schema currently picks an arbitrary schema from
multi-overload packets which is non-deterministic; update get_op_schema to (1)
accept explicit types (hint op: Union[torch._ops.OpOverloadPacket,
torch._ops.OpOverload]) via type hints, (2) check for a single-overload
attribute `_schema` first and return it, (3) when `_schemas` is present prefer
and return the `"default"` key if it exists, and (4) if multiple schemas exist
and no `"default"` is present raise a RuntimeError describing the ambiguous
OpOverloadPacket instead of using next(iter(...)); keep the function name
get_op_schema and callers intact.

In `@tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py`:
- Around line 1248-1261: The assertions depend on dict insertion order when
accessing interface._caches; instead, index the caches deterministically using
the resource names passed to add_resource ("kv_0" and "kv_1"). After
interface.initialize_resources(), replace the lookups that use
list(interface._caches.keys())[0/1] with direct access interface._caches["kv_0"]
and interface._caches["kv_1"] (used where kv_0 and kv_1 are assigned) so the
shape assertions reference the correct cache groups reliably.

---

Outside diff comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 432-456: The decorator for torch_backend_mha_with_cache
incorrectly declares mutates_args=() while the function mutates k_cache and
v_cache via _write_generate_kv_cache and _update_kv_cache; update the
`@torch.library.custom_op` on torch_backend_mha_with_cache to list the mutated
tensor arguments (k_cache and v_cache) in mutates_args so the op contract
matches the actual in-place updates and prevents torch.compile misoptimizations.

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py`:
- Around line 977-992: The returned KV stats and total_managed use only the last
group's kv_managed; update the aggregation to use the union across all groups
(kv_managed_all) instead. Replace usages of kv_managed when computing
total_managed and the "kv_managed" return value with kv_managed_all (i.e.,
compute total_managed = len(kv_managed_all) + ssm_managed_count +
conv_managed_count and return "kv_managed": len(kv_managed_all)); keep other
derived counts (paged_total, kv_total, paged_other, other_total) unchanged. Use
the existing symbols kv_managed_all, ssm_managed_count, conv_managed_count,
total_managed, and the return dict in this function to locate where to apply the
change.
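
The union-based aggregation reads roughly as follows; the group sets here are hypothetical stand-ins for the real bookkeeping structures in interface.py:

```python
# Sketch of the suggested aggregation: union layer names across ALL KV groups
# rather than taking only the last group's set. Input shapes are assumptions.
def aggregate_kv_stats(kv_groups, ssm_managed_count, conv_managed_count):
    kv_managed_all = set()
    for kv_managed in kv_groups:
        kv_managed_all |= set(kv_managed)  # union across all groups, not just the last
    total_managed = len(kv_managed_all) + ssm_managed_count + conv_managed_count
    return {"kv_managed": len(kv_managed_all), "total_managed": total_managed}
```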

---

Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py`:
- Around line 857-862: The SWA eviction message currently uses ad_logger.info
and can be noisy; change the log level to debug in the SWA eviction block inside
ad_executor.py (the conditional that checks front_removed > 0 and i == 0) so
routine eviction details are emitted with ad_logger.debug instead of
ad_logger.info, keeping the same message text and context (group_idx,
window_size, request.py_request_id, len(all_indices), front_removed, num_active,
page_offset_g) to preserve diagnostic data.

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py`:
- Around line 47-165: MultiPoolKVCacheManager is missing type hints and public
docstrings; update the class and its public surface to be a proper drop-in
KVCacheManager replacement by adding PEP484 type annotations and short
docstrings: annotate __init__(self, managers: List[KVCacheManager], primary_idx:
int = 0) -> None and document parameters, add return types for properties (e.g.,
impl -> Any or the actual Impl type, tokens_per_block -> int, max_blocks_per_seq
-> int, blocks_in_primary_pool -> int), methods (get_num_free_blocks() -> int,
get_max_resource_count() -> int, get_needed_resource_to_completion(request:
RequestType) -> int or appropriate type, get_num_kv_blocks(num_tokens: int) ->
int, prepare_resources(scheduled_batch: ScheduledBatchType) -> None,
free_resources(request: RequestType, pin_on_release: bool = False) -> None,
update_resources(...)-> None, add_dummy_requests(request_ids: Sequence[str],
**kwargs) -> Any, shutdown() -> None, get_pool(group_idx: int) ->
KVCacheManager, num_pools -> int, max_concurrent_sequences -> int,
get_buffers(idx: int, kv_layout: str = "NHD") -> BufferType (or raise
NotImplementedError with a docstring explaining alternative),
event_buffer_max_size -> int, enable_block_reuse -> bool, enable_partial_reuse
-> bool, is_draft -> bool, kv_cache_pool_pointers -> PointerType,
kv_cache_pool_mapping -> MappingType, get_cache_indices(request: RequestType,
**kwargs) -> IndexType, store_blocks_for_reuse(request: RequestType, pin_blocks:
bool = False) -> None; include short docstrings on the class and each public
method/prop (at least __init__, get_buffers, get_pool, and get_cache_indices)
describing behavior and any divergence from single-pool KVCacheManager so
callers and type checkers can rely on it.
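
A trimmed, hypothetical skeleton of the annotation style being asked for — only a few members are shown, and the `KVCacheManager` stub here is a stand-in for the real class:

```python
from typing import List


class KVCacheManager:
    """Stand-in for the real tensorrt_llm KVCacheManager (illustrative)."""
    tokens_per_block: int = 32


class MultiPoolKVCacheManager:
    """Unified wrapper over one KVCacheManager per head_dim/window group.

    Sketch of the annotated surface suggested above; most members omitted.
    """

    def __init__(self, managers: List[KVCacheManager], primary_idx: int = 0) -> None:
        """managers: one pool per group; primary_idx: the full-attention pool."""
        self._managers = managers
        self._primary_idx = primary_idx

    @property
    def num_pools(self) -> int:
        """Number of underlying KV pools."""
        return len(self._managers)

    @property
    def tokens_per_block(self) -> int:
        """Delegates to the primary (full-attention) pool."""
        return self._managers[self._primary_idx].tokens_per_block

    def get_pool(self, group_idx: int) -> KVCacheManager:
        """Return the per-group pool; a divergence from the single-pool API."""
        return self._managers[group_idx]
```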

In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py`:
- Around line 982-984: EXTRA_EVALUATOR_KWARGS is defined as a mutable class
attribute; change it to a ClassVar-annotated constant or convert it into a
property to satisfy static analysis: annotate the symbol EXTRA_EVALUATOR_KWARGS
as ClassVar[dict[str, Any]] (import ClassVar and Any from typing) or replace it
with a `@property` that returns a fresh dict (e.g., def
EXTRA_EVALUATOR_KWARGS(self) -> dict[str, Any]: return {"apply_chat_template":
True}); update the declaration and imports accordingly.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py`:
- Around line 446-449: The test function
test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing
uses an early return when CUDA is unavailable; replace that return with a call
to pytest.skip("CUDA is not available") so the test is recorded as skipped, and
ensure pytest is imported at the top of the test module if not already present.
- Around line 96-126: The variable `batch` in the `_manual_attention` function
is unused and triggers a static analysis warning; change the unpacking from
`batch, seq_len_q, num_heads, _ = q.shape` to prefix the unused variable (e.g.,
`_batch, seq_len_q, num_heads, _ = q.shape`) or otherwise rename it to `_batch`
to silence RUF059 and indicate it is intentionally unused; update any references
if you choose a different name.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py`:
- Around line 135-137: This test file repeats a long positional argument list
when invoking the FlashInfer attention op, which is brittle; create a small
local helper (e.g., call_flashinfer_attention or flashinfer_attention_helper) in
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
that wraps the actual op invocation and accepts either a partial kwargs dict or
the same parameters with sensible defaults, then replace each repeated call-site
(the clusters around lines shown in the comment) to call that helper instead;
update occurrences referenced in the review (around the groups at 135-137,
265-267, 396-398, 491-493, 625-627, 784-786, 892-894, 987-989) so future
arity/order changes only need updating in the single helper.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py`:
- Around line 493-496: The first assignment to q_pos (q_pos = torch.arange(s_k -
s_q + s_q, device=q.device)) is dead code because it is immediately overwritten
by the later assignment; remove that redundant line and retain the intended
prefill assignment (q_pos = torch.arange(s_k - s_q, s_k, device=q.device)) and
k_pos assignment (k_pos = torch.arange(s_k, device=q.device)); also update or
keep the inline comment to reflect that q_pos now represents absolute positions
for prefill.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6358dc4d-9231-49e3-ba30-34e14606b639

📥 Commits

Reviewing files that changed from the base of the PR and between a1777fd and 9cf51a9.

📒 Files selected for processing (50)
  • cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
  • docs/source/models/supported-models.md
  • examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
  • examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
  • examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
  • examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
  • examples/auto_deploy/model_registry/models.yaml
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
  • tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
  • tensorrt_llm/_torch/auto_deploy/export/export.py
  • tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/shim/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
  • tensorrt_llm/_torch/auto_deploy/utils/_graph.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py
  • tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py

"source": [
"from openai import OpenAI\n",
"\n",
"BASE_URL = \"http://0.0.0.0:8000/v1\"\n",
⚠️ Potential issue | 🟡 Minor

Use a routable client endpoint instead of 0.0.0.0.

Line 160 should use 127.0.0.1 (or localhost) for client requests; 0.0.0.0 is an address for a server to bind to, not for a client to connect to.

🔧 Proposed fix
-BASE_URL = "http://0.0.0.0:8000/v1"
+BASE_URL = "http://127.0.0.1:8000/v1"

Comment on lines 903 to 905
        if estimated_capacity > cache_loc_capacity:
            self._input_buffer.resize("cache_loc", estimated_capacity)


⚠️ Potential issue | 🟠 Major

Keep cache_loc_g* in sync when cache_loc is resized.

register_window_groups() snapshots the current cache_loc capacity into each cache_loc_g*, but Line 904 only grows the base tensor. If max_num_tokens is smaller than max_batch_size * max_blocks_per_seq, staging cache_loc_per_group can start failing even though cache_loc itself was resized successfully.

Possible fix
         if estimated_capacity > cache_loc_capacity:
             self._input_buffer.resize("cache_loc", estimated_capacity)
+            for group_idx in range(1, self.num_window_groups):
+                self._input_buffer.resize(f"cache_loc_g{group_idx}", estimated_capacity)

Also applies to: 919-924


Comment on lines +257 to +261
def _to_flashinfer_window_left(sliding_window: Optional[int]) -> int:
    """Convert AD sliding-window size to FlashInfer's inclusive window_left contract."""
    if sliding_window is None or sliding_window <= 0:
        return -1
    return sliding_window - 1

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Scripts executed against the NVIDIA/TensorRT-LLM repository (deduplicated below; the per-script "Length of output" metadata and collapsed outputs from the original page are omitted):

# Examine the file structure and key regions of flashinfer_attention.py
wc -l tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | head -100
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '135,158p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '135,160p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '159,200p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '200,250p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '250,270p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '313,330p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '331,365p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '334,460p'
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | sed -n '355,380p'

# Locate the planning entry points and their use of window_left
rg -n "def plan_generate_only" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -A 30
rg -n "def plan_prefill|def plan_decode|def plan" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
rg -n "class PlanParams" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -A 20
rg -n "wrapper\.plan" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -B 2 -A 2
rg -n "window_left" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
rg -n "causal" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -B 2 -A 2
rg -n "causal" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | grep -i "fast_decode"
rg -n "fast_decode_plan" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -B 2 -A 5
rg -n "prepare_flashinfer_metadata_host" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py -A 50
rg -n "flashinfer_mha_with_cache|prepare_flashinfer_metadata" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | grep "def "

# Cross-check repo-wide usage of the relevant APIs
rg -n "fast_decode_plan" tensorrt_llm/ -B 3 -A 10 --type py
rg -n "flashinfer\.decode\.fast_decode_plan" tensorrt_llm -B 5 -A 15
rg -n "window_left" tensorrt_llm --type py
rg -n "prepare_flashinfer_metadata_host" tensorrt_llm/ --type py -B 3 -A 3
find tensorrt_llm -name "*.py" -type f -exec grep -l "flashinfer" {} \;

Wire window_left through the decode-only planning path.

The sliding-window support is incomplete. While window_left is properly threaded through plan_prefill() and plan_decode() when both phases exist, the decode-only fast path still drops it:

  • prepare_flashinfer_metadata_host() (line 313) has no sliding_window parameter, so it cannot pass window information to plan_generate_only()
  • plan_generate_only() calls flashinfer.decode.fast_decode_plan() (line 145) without window_left, causing SWA decode under CUDA-graph warmup/replay to be planned as full attention

Either:

  1. Add sliding_window parameter to prepare_flashinfer_metadata_host() and pass it through to plan_generate_only(), then to fast_decode_plan(), or
  2. Reroute pure-decode batches through the normal plan_decode() path instead of the plan_generate_only() fast path when sliding-window is active.
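
A pure-Python sketch of the window_left conversion and of option 2's routing decision. This deliberately avoids the real flashinfer calls; the helper names and returned plan labels are illustrative assumptions, not the actual API:

```python
from typing import Optional


def to_window_left(sliding_window: Optional[int]) -> int:
    """Mirror of _to_flashinfer_window_left: window size -> inclusive window_left."""
    if sliding_window is None or sliding_window <= 0:
        return -1  # FlashInfer convention for "no window" (full attention)
    return sliding_window - 1


def choose_decode_plan(is_decode_only: bool, sliding_window: Optional[int]) -> str:
    """Illustration of option 2: keep SWA batches off the fast decode path."""
    window_left = to_window_left(sliding_window)
    if is_decode_only and window_left < 0:
        # Fast path is only safe when there is no sliding window to honor.
        return "plan_generate_only"
    # Normal path already threads window_left through to the planner.
    return "plan_decode"
```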

Comment on lines 1151 to +1163
        seq_len_with_cache = seq_len_with_cache_host[:num_prefill].to(q.device, non_blocking=True)
        # For windowed cache_loc (VSWA), cap the cached-token portion of
        # seq_len_with_cache to the actual pages available. Without this,
        # the context kernel computes page iteration bounds from global
        # seq_len, overflowing the windowed cache_loc.
        # seq_len_with_cache = cache_len + q_len, where cache_len is the
        # number of prior-cached tokens. Only cache_len needs capping.
        q_lens = cu_seqlen[1 : num_prefill + 1] - cu_seqlen[:num_prefill]
        page_counts = cu_num_pages[1 : num_prefill + 1] - cu_num_pages[:num_prefill]
        max_cached = page_counts * kv_cache.shape[3]  # pages × page_size
        cache_len_raw = seq_len_with_cache - q_lens
        cache_len_capped = torch.minimum(cache_len_raw, max_cached)
        seq_len_with_cache = cache_len_capped + q_lens

⚠️ Potential issue | 🔴 Critical

Don't rewrite seq_len_with_cache into window-local coordinates here.

kv_page_offset makes _paged_context_kernel interpret KV pages in absolute positions, but this cap turns seq_len_with_cache into a local length. After front eviction, q_positions_2d / first_q_kv_pos become local while kv_base_pos stays global, so the causal/SWA masks are evaluated in different coordinate systems. Prefill/extend after eviction can then admit stale tokens from the first retained page or skip valid later pages. Keep the absolute seq_len_with_cache for masking and derive local page bounds separately.
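
A torch-free numeric sketch of the proposed split (values are made up): masking keeps the absolute length while page iteration uses a separately capped local length.

```python
# Illustration of the suggested fix: derive a *local* length for page-iteration
# bounds, but leave the absolute seq_len_with_cache untouched for masking.
def split_lengths(seq_len_with_cache: int, q_len: int, page_count: int, page_size: int):
    cache_len_raw = seq_len_with_cache - q_len
    max_cached = page_count * page_size               # pages actually available
    cache_len_capped_local = min(cache_len_raw, max_cached)
    seq_len_with_cache_local = cache_len_capped_local + q_len
    # masking uses the absolute value; page iteration uses the local one
    return seq_len_with_cache, seq_len_with_cache_local
```

With e.g. a 5000-token sequence, 64 query tokens, and 16 retained pages of 64 tokens each, the absolute length stays 5000 for masking while page iteration sees 16 × 64 + 64 = 1088.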

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 1160-1160: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


    kv_scale_quant_orig: float = 1.0,
    out_scale: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    sink_token_length: int = 0,

⚠️ Potential issue | 🔴 Critical

Keep sink_token_length in the constant prefix.

TrtllmAttention.get_constants() now returns six positional constants, but both cached-op signatures still place out between out_scale and sink_token_length. That makes the sixth constant bind to out, so the transformed TRT-LLM path will either treat an int as the output buffer or never pass sink_token_length through.

🩹 Proposed fix
 def trtllm_mha_with_cache(
@@
     kv_scale_orig_quant: float = 1.0,
     kv_scale_quant_orig: float = 1.0,
     out_scale: Optional[torch.Tensor] = None,
-    out: Optional[torch.Tensor] = None,
     sink_token_length: int = 0,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@
 def trtllm_mha_with_cache_fake(
@@
     kv_scale_orig_quant: float = 1.0,
     kv_scale_quant_orig: float = 1.0,
     out_scale: Optional[torch.Tensor] = None,
-    out: Optional[torch.Tensor] = None,
     sink_token_length: int = 0,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

Also applies to: 535-535, 635-687
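
A toy illustration (not the real op registration) of why the parameter order matters: the constants prefix is appended positionally, so a parameter declared between the constants captures the wrong value. All names and values here are hypothetical.

```python
# Constants are appended after the dynamic args, binding purely by position.
def call_with_constants(fn, *dynamic_args, constants):
    return fn(*dynamic_args, *constants)


def cached_op_wrong(q, scale, out_scale=None, out=None, sink_token_length=0):
    # BUG analog: `out` sits before sink_token_length, so the last constant
    # (an int) lands in `out` and sink_token_length keeps its default.
    return {"out": out, "sink_token_length": sink_token_length}


def cached_op_fixed(q, scale, out_scale=None, sink_token_length=0, out=None):
    # Fix analog: sink_token_length stays inside the constants prefix;
    # `out` moves after it.
    return {"out": out, "sink_token_length": sink_token_length}
```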

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py` at
line 337, The constant list in TrtllmAttention.get_constants() now includes
sink_token_length as the sixth positional constant, but the cached-op signatures
incorrectly place the output buffer parameter (out) between out_scale and
sink_token_length; update the cached-op call/signature(s) so that
sink_token_length remains in the constants prefix (i.e., keep sink_token_length
as the sixth positional constant returned by TrtllmAttention.get_constants())
and move the out parameter after the constants in the cached-op invocation;
apply the same fix to the other cached-op occurrences referenced around the
blocks at the other locations (the similar cached-op signatures near the later
ranges).
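The positional-binding hazard behind this finding can be sketched in plain Python (illustrative names only, not the real op signatures): constants emitted as a positional suffix bind strictly by position, so a parameter inserted before `sink_token_length` silently absorbs its value.

```python
# Sketch: why the constant suffix must match the signature's trailing-default
# order. `cached_op` / `cached_op_fixed` are illustrative stand-ins.
def cached_op(q, out_scale=None, out=None, sink_token_length=0):
    return (out_scale, out, sink_token_length)

# Constants emitted positionally as (..., out_scale, sink_token_length):
constants = (None, 7)  # 7 is meant to be sink_token_length

# With `out` between out_scale and sink_token_length, 7 binds to `out`:
assert cached_op("q", *constants) == (None, 7, 0)

# After moving `out` past sink_token_length, 7 binds where intended:
def cached_op_fixed(q, out_scale=None, sink_token_length=0, out=None):
    return (out_scale, out, sink_token_length)

assert cached_op_fixed("q", *constants) == (None, None, 7)
```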

Comment on lines +856 to 912
+        # 2. Create one KVCacheManager per group
+        # SWA groups (window < max_seq_len) get fixed max_tokens.
+        # Full-attention groups get the remaining budget via max_tokens or free_gpu_mem_fraction.
+        managers: List[KVCacheManager] = []
+        primary_idx = 0  # index of the full-attention (largest-window) group
+        max_window_seen = 0
+
+        for group_idx, (kv_ref, kv_managed) in enumerate(kv_groups):
+            # Compute this group's token budget
+            group_max_tokens = self._compute_group_token_budget(
+                group_idx, kv_ref, kv_managed, kv_groups, max_tokens
+            )
+            group_config = self._prepare_kv_cache_config(group_max_tokens, kv_managed)
+            group_kwargs = self._build_kv_cache_kwargs(kv_ref, kv_managed, group_config)
+
+            # NOTE: SWA groups keep max_seq_len from config (NOT window_size).
+            # During prefill, sequences temporarily use up to max_seq_len blocks.
+            # max_attention_window evicts old blocks during decode, freeing them
+            # for new sequences. The SWA savings are throughput (more concurrent
+            # decode sequences), not peak memory reduction.
+
+            if has_state_resources and group_idx == 0:
+                group_kwargs["max_batch_size"] = self.info.max_num_state_slots
+                mgr, _ = self._create_and_assign_state_views(
+                    group_kwargs,
+                    ssm_ref,
+                    ssm_managed,
+                    ssm_spec,
+                    conv_ref,
+                    conv_managed,
+                    conv_spec,
+                )
+            else:
+                mgr = KVCacheManager(**group_kwargs)
-        # 3. Create cache manager (delegate to state helper if state resources exist)
-        has_state_resources = ssm_managed or conv_managed
-        if has_state_resources:
-            # NOTE: +1 for cuda graph padding
-            kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots
-            self._kv_cache_manager, _ = self._create_and_assign_state_views(
-                kv_cache_kwargs,
-                ssm_ref,
-                ssm_managed,
-                ssm_spec,
-                conv_ref,
-                conv_managed,
-                conv_spec,
+            managers.append(mgr)
+            is_swa = self._is_swa_group(kv_managed)
+            ad_logger.info(
+                f"KV pool {group_idx}: {len(kv_managed)} layers, "
+                f"head_dim={kv_ref.head_dim}, "
+                f"max_attention_window={group_config.max_attention_window}, "
+                f"swa={is_swa}, "
+                f"max_tokens={group_max_tokens}"
+            )
-        else:
-            # No typed state resources - use pure KVCacheManager
-            self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs)
-
-        # 4. Store tuned config
-        self._kv_cache_config_tuned = kv_cache_config
+            # Track which group has the largest window (= primary for scheduler)
+            group_window = max(group_config.max_attention_window or [self.info.max_seq_len])
+            if group_window > max_window_seen:
+                max_window_seen = group_window
+                primary_idx = group_idx
+
+        # 3. Store manager (wrapper if multi-group, direct if single)
+        if len(managers) == 1:
+            self._kv_cache_manager = managers[0]
+        else:
+            self._kv_cache_manager = MultiPoolKVCacheManager(managers, primary_idx=primary_idx)

Contributor


⚠️ Potential issue | 🔴 Critical

Handle the no-KV fallback before constructing the wrapper.

When kv_groups is empty, the loop on Lines 863-891 never appends to managers, but Line 911 still constructs a MultiPoolKVCacheManager and Lines 925-928 immediately index managers[primary_idx]. That regresses pure-state or fully unmanaged graphs into an IndexError instead of creating a MambaHybridCacheManager or just allocating local resources.

Also applies to: 925-928

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 856 - 912,
The code assumes managers is non-empty and will raise IndexError for empty
kv_groups; before constructing MultiPoolKVCacheManager or indexing
managers[primary_idx], check if managers is empty and handle the no-KV fallback:
if no managers, set self._kv_cache_manager to the appropriate fallback (e.g.,
instantiate MambaHybridCacheManager or allocate local resources) instead of
creating MultiPoolKVCacheManager, ensuring the same fallback is used where
managers[primary_idx] would be accessed; update the branch that currently
assigns self._kv_cache_manager (the block that chooses between single manager
and MultiPoolKVCacheManager) to first handle len(managers) == 0, then
len(managers) == 1, then the multi-case, and reference the symbols
_kv_cache_manager, managers, primary_idx, MultiPoolKVCacheManager, and
MambaHybridCacheManager.
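The three-way branch suggested by this finding can be sketched with stubs (the fallback factory and wrapper here are stand-ins for the real `MambaHybridCacheManager` / `MultiPoolKVCacheManager` wiring):

```python
class MultiPoolKVCacheManager:
    """Stub standing in for the real multi-pool wrapper."""

    def __init__(self, managers, primary_idx=0):
        self.managers = managers
        self.primary = managers[primary_idx]  # would raise IndexError if empty


def select_kv_manager(managers, primary_idx, fallback_factory):
    """Handle the empty case first, then single-pool, then multi-pool."""
    if not managers:
        return fallback_factory()  # no-KV fallback instead of an IndexError
    if len(managers) == 1:
        return managers[0]  # single pool: use it directly
    return MultiPoolKVCacheManager(managers, primary_idx=primary_idx)


assert select_kv_manager([], 0, lambda: "fallback") == "fallback"
assert select_kv_manager(["m0"], 0, lambda: None) == "m0"
wrapper = select_kv_manager(["m0", "m1"], 1, lambda: None)
assert wrapper.primary == "m1"
```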

Comment on lines +1052 to 1058
def get_op_schema(op) -> torch.FunctionSchema:
    """Return the schema for an op or op overload packet."""
    if hasattr(op, "_schemas"):
        return next(iter(op._schemas.values()))
    if hasattr(op, "_schema"):
        return op._schema
    raise RuntimeError(f"No schema found on op {op}")
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== get_op_schema callsites =="
rg -n --type=py '\bget_op_schema\('

echo
echo "== Direct packet-style calls without explicit .default =="
rg -n --type=py 'get_op_schema\(\s*torch\.ops\.[^)]*\)' | rg -v '\.default\)'

echo
echo "== Overload cardinality for direct torch.ops callsites =="
python - <<'PY'
import pathlib, re, torch

pattern = re.compile(r'get_op_schema\(\s*(torch\.ops\.[A-Za-z0-9_\.]+)\s*\)')
exprs = set()

for p in pathlib.Path(".").rglob("*.py"):
    text = p.read_text(encoding="utf-8", errors="ignore")
    exprs.update(m.group(1) for m in pattern.finditer(text))

for expr in sorted(exprs):
    if expr.endswith(".default"):
        continue
    try:
        op = eval(expr, {"torch": torch})
    except Exception as e:
        print(f"{expr} -> unresolved ({type(e).__name__})")
        continue

    schemas = getattr(op, "_schemas", None)
    if schemas is None:
        print(f"{expr} -> explicit overload/single schema")
    else:
        print(f"{expr} -> overload_count={len(schemas)} keys={list(schemas.keys())[:8]}")
PY

Repository: NVIDIA/TensorRT-LLM

Length of output: 808


🏁 Script executed:

# Inspect each call site context
echo "=== _graph.py:75 ==="
sed -n '70,80p' tensorrt_llm/_torch/auto_deploy/utils/_graph.py

echo
echo "=== export.py:279 ==="
sed -n '274,284p' tensorrt_llm/_torch/auto_deploy/export/export.py

echo
echo "=== kvcache.py:111 ==="
sed -n '106,116p' tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

echo
echo "=== attention_interface.py:71 ==="
sed -n '66,76p' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

echo
echo "=== node_utils.py:1061-1065 ==="
sed -n '1061,1065p' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2186


🏁 Script executed:

# Try simpler Python verification of torch.ops overload counts
python3 <<'PYEOF'
import sys
import torch

# Test a few common torch.ops to see what structure they have
test_ops = [
    ("torch.ops.aten.add", torch.ops.aten.add),
]

for name, op_ref in test_ops:
    has_schemas = hasattr(op_ref, "_schemas")
    has_schema = hasattr(op_ref, "_schema")
    print(f"{name}:")
    print(f"  has _schemas: {has_schemas}")
    print(f"  has _schema: {has_schema}")
    
    if has_schemas:
        schemas = op_ref._schemas
        print(f"  _schemas keys: {list(schemas.keys())[:10]}")
        print(f"  schema count: {len(schemas)}")
    if has_schema:
        print(f"  _schema: {op_ref._schema}")
PYEOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 180


🏁 Script executed:

# Search for how node.target gets assigned and what type it can be
echo "=== Searching for node.target assignments/usage ==="
rg -n 'node\.target' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1402


🏁 Script executed:

# Understand what type node.target can be when op="call_function"
echo "=== FX Node type hint documentation ==="
rg -A5 'node.op.*call_function' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

echo
echo "=== kvcache.py context for prep_meta_op ==="
sed -n '90,115p' tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

echo
echo "=== Search for call_function node creation ==="
rg -n 'call_function' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -15

Repository: NVIDIA/TensorRT-LLM

Length of output: 2503


🏁 Script executed:

# Look at the actual Node type and understand node.target semantics
echo "=== Node.target type from fx imports ==="
rg -n 'from torch' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -10

echo
echo "=== Check if there's type narrowing logic ==="
rg -B2 -A5 'node.op == "call_function"' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 588


🏁 Script executed:

# Search export.py for node.target context when get_op_schema is called
echo "=== export.py context around line 279 ==="
sed -n '265,290p' tensorrt_llm/_torch/auto_deploy/export/export.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 967


🏁 Script executed:

# Understand what is_op does and if it narrows the type
echo "=== is_op function definition ==="
sed -n '450,480p' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

echo
echo "=== Check if node.target from is_op match is always OpOverload ==="
rg -B5 -A10 'def is_op' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1661


🏁 Script executed:

# Check the actual behavior: what does get_prepare_extra_metadata_info return?
echo "=== attn_descriptor.get_prepare_extra_metadata_info return type ==="
rg -n 'get_prepare_extra_metadata_info' tensorrt_llm/_torch/auto_deploy/transform/library/

echo
echo "=== Look for its definition ==="
rg -B3 -A15 'def get_prepare_extra_metadata_info' tensorrt_llm/_torch/auto_deploy/

Repository: NVIDIA/TensorRT-LLM

Length of output: 13016


🏁 Script executed:

# Most importantly: has the _schemas branch ever been executed?
# Look for test coverage or usage of OpOverloadPacket with _schemas
echo "=== Search for _schemas usage ==="
rg -n '\._schemas' tensorrt_llm/_torch/auto_deploy/

echo
echo "=== Search for OpOverloadPacket direct passing ==="
rg -n 'OpOverloadPacket' tensorrt_llm/_torch/auto_deploy/

Repository: NVIDIA/TensorRT-LLM

Length of output: 5458


Avoid selecting an arbitrary schema from multi-overload packets in get_op_schema().

When node.target is an OpOverloadPacket (possible since is_op() explicitly matches both OpOverload and OpOverloadPacket), next(iter(op._schemas.values())) at line 1055 selects a non-deterministic schema, which can cause argument lookup failures in extract_op_args() and set_op_args().

The most direct risk is in export.py:279, where after is_op(node, moe_ops) (with OpOverloadPacket objects), node.target is passed directly without normalizing to an explicit overload.

Proposed fix:

  1. Reorder checks to prefer _schema (single overload) before _schemas
  2. When handling multi-overload packets, prefer "default" key if present
  3. Raise an error for ambiguous packets rather than selecting arbitrarily
  4. Add type hints to clarify accepted input types
Suggested implementation
-def get_op_schema(op) -> torch.FunctionSchema:
-    """Return the schema for an op or op overload packet."""
-    if hasattr(op, "_schemas"):
-        return next(iter(op._schemas.values()))
-    if hasattr(op, "_schema"):
-        return op._schema
+def get_op_schema(op: OpOverloadPacket | OpOverload) -> torch.FunctionSchema:
+    """Return the schema for an op overload or overload packet."""
+    if hasattr(op, "_schema"):
+        return op._schema
+    if hasattr(op, "_schemas"):
+        schemas = op._schemas
+        if "default" in schemas:
+            return schemas["default"]
+        if len(schemas) == 1:
+            return next(iter(schemas.values()))
+        raise RuntimeError(
+            f"Ambiguous overload packet {op}: multiple overloads {list(schemas.keys())}. "
+            "Pass an explicit overload (e.g., op.default)."
+        )
     raise RuntimeError(f"No schema found on op {op}")
 
-def _get_op_schema(node: Node):
+def _get_op_schema(node: Node) -> torch.FunctionSchema:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py` around lines 1052 -
1058, get_op_schema currently picks an arbitrary schema from multi-overload
packets which is non-deterministic; update get_op_schema to (1) accept explicit
types (hint op: Union[torch._ops.OpOverloadPacket, torch._ops.OpOverload]) via
type hints, (2) check for a single-overload attribute `_schema` first and return
it, (3) when `_schemas` is present prefer and return the `"default"` key if it
exists, and (4) if multiple schemas exist and no `"default"` is present raise a
RuntimeError describing the ambiguous OpOverloadPacket instead of using
next(iter(...)); keep the function name get_op_schema and callers intact.

Comment on lines +1248 to +1261
interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))

interface.initialize_resources()

# Group 0 (head_dim=64): cache shape [..., 8, 32, 64]
kv_0 = interface._caches[list(interface._caches.keys())[0]]
assert kv_0.shape[-1] == 64
assert kv_0.shape[-3] == 8 # num_kv_heads

# Group 1 (head_dim=128): cache shape [..., 4, 32, 128]
kv_1 = interface._caches[list(interface._caches.keys())[1]]
assert kv_1.shape[-1] == 128
assert kv_1.shape[-3] == 4 # num_kv_heads
Contributor


⚠️ Potential issue | 🟡 Minor

Avoid dict-order dependence in cache-shape assertions.

Line 1254 and Line 1259 rely on _caches key order. This can make the test flaky if insertion or registration order changes. Use the returned resource names from add_resource(...) for deterministic lookup.

✅ Suggested deterministic lookup
-    interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
-    interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))
+    kv_0_name = interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
+    kv_1_name = interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))
@@
-    kv_0 = interface._caches[list(interface._caches.keys())[0]]
+    kv_0 = interface._caches[kv_0_name]
@@
-    kv_1 = interface._caches[list(interface._caches.keys())[1]]
+    kv_1 = interface._caches[kv_1_name]
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 1254-1254: Prefer next(iter(interface._caches.keys())) over single element slice

Replace with next(iter(interface._caches.keys()))

(RUF015)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py`
around lines 1248 - 1261, The assertions depend on dict insertion order when
accessing interface._caches; instead, index the caches deterministically using
the resource names passed to add_resource ("kv_0" and "kv_1"). After
interface.initialize_resources(), replace the lookups that use
list(interface._caches.keys())[0/1] with direct access interface._caches["kv_0"]
and interface._caches["kv_1"] (used where kv_0 and kv_1 are assigned) so the
shape assertions reference the correct cache groups reliably.

@suyoggupta suyoggupta force-pushed the sg/swa branch 2 times, most recently from 9917059 to e4da73b Compare April 7, 2026 20:45
…sliding window attention

Adds dual-pool KV cache architecture for models with mixed attention types
(e.g., gemma4-26B with head_dim=256 sliding + head_dim=512 full attention).
Each head_dim group gets its own KVCacheManager pool with independent
max_attention_window, enabling SWA block eviction during decode.

Architecture: WindowPlan is the single source of truth for VSWA. It
separates logical attention-window routing (which layers share page tables)
from physical KV storage pooling (which layers share block pools). Both
graph wiring and runtime metadata emission derive from it, eliminating
predicate drift between the transform and executor.

Key changes:
- WindowPlan dataclass: per_layer_window, unique_windows, group indices,
  group_to_pool_idx mapping (decouples window groups from storage pools)
- MultiPoolKVCacheManager: delegates lifecycle to all storage pools
- _identify_managed_kv_groups: groups layers by (head_dim, dtype, kv_factor)
- Per-group cache_loc/cu_num_pages/kv_page_offset via VSWA graph wiring
- kv_page_offset in write kernel for window-relative page indexing
- kv_page_offset in context kernel for correct position-based masking
- cache_len capping from cu_num_pages in triton_paged_mha_with_cache
- get_num_front_blocks_removed C++ binding for SWA eviction tracking
- N-based proportional memory budget split across pools
- max_concurrent_sequences scheduler cap for multi-pool safety
- Unit tests for multi-group identification, dual-pool creation, and
  per-group max_attention_window scoping

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
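The WindowPlan idea described above can be sketched as a small dataclass; the field names follow the commit message, but the real definition in the PR may differ:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class WindowPlan:
    """Hedged sketch: single source of truth mapping layers -> window groups
    -> storage pools, per the commit message."""

    per_layer_window: List[int]       # attention window per layer
    unique_windows: List[int]         # sorted distinct windows (one per group)
    group_to_pool_idx: Dict[int, int] # logical window group -> physical pool

    def layer_group(self, layer_idx: int) -> int:
        return self.unique_windows.index(self.per_layer_window[layer_idx])


# Example: three SWA layers (window 1024) and one full-attention layer.
plan = WindowPlan(
    per_layer_window=[1024, 1024, 32768, 1024],
    unique_windows=[1024, 32768],
    group_to_pool_idx={0: 0, 1: 1},
)
assert plan.layer_group(2) == 1                            # full-attention group
assert plan.group_to_pool_idx[plan.layer_group(0)] == 0    # SWA layers -> pool 0
```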
Member

@lucaslie lucaslie left a comment


Okay, here's my overall thought. I think this is a very nice design, and it makes a lot of sense. We will start needing to support different groups, including different groups for metadata.
That being said, maybe this is just personal taste, but I do think that we could change the design slightly in order to clean up some of the dependencies.

  1. To begin with, I would love for all of the cache-related logic to be encapsulated into the resource handler classes. This way, there's a very clean interface between individual layer-wise attention operators and our Cache Management Interface in the shim.
  2. Next up is what I think is the most important part of the change in the shim. Right now, instead of just naively collecting all the resource handlers, we might have to do some analysis of them on the fly. In particular, we have to analyze on the fly when we can put something into an existing KV manager and when we have to initialize a new group with a new KV manager. This design, regardless of sliding window or other features inside attention, should prove to be very scalable in the future as well.
  3. As part of that, we can also initialize a new group of metadata fields from the cache sequence interface in the attention interface. Ideally, we can carry that concept of a group over to the attention interface and use a standardized way to initialize a new group of metadata for that new KV Cache Manager.
  4. Now that all this information is in place, we can go back to the KV Cache Transform. When the KV Cache Transform requests certain metadata arguments, we can now return the nodes/metadata inputs that correspond to the particular group. This way, we can dynamically insert the correct group.
  5. The final step to tie it all together is, of course, in prepare inputs, where for each of the KV cache managers (each corresponding to a different group) we prep the metadata for all the groups and pass it into the attention interface.

What do you think of this, given that the change we're introducing here is very heavy? It might be worth digging a little deeper.
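Step 2 above (on-the-fly analysis of which handlers can share a KV manager) could amount to grouping handlers by a compatibility key; a minimal sketch, with handler fields assumed for illustration:

```python
from collections import defaultdict


def group_handlers(handlers):
    """Group resource handlers by a compatibility key (head_dim, dtype);
    each resulting group would get its own KV cache manager."""
    groups = defaultdict(list)
    for name, h in handlers.items():
        groups[(h["head_dim"], h["dtype"])].append(name)
    return dict(groups)


# e.g. gemma4-style mix: sliding layers at head_dim=256, full attention at 512
handlers = {
    "l0": {"head_dim": 256, "dtype": "bf16"},
    "l1": {"head_dim": 512, "dtype": "bf16"},
    "l2": {"head_dim": 256, "dtype": "bf16"},
}
assert group_handlers(handlers) == {
    (256, "bf16"): ["l0", "l2"],
    (512, "bf16"): ["l1"],
}
```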

 @classmethod
-def get_constants(cls, source_attn_node: Node) -> List[Constant]:
+def get_constants(
+    cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None
Member


Why can the sink token length not be extracted from the node just like the other constants?

Collaborator


This is a full device sync.

Syncing a specific stream would be a finer-grained approach and avoids destroying GPU pipelining. I think we should do a stream sync here.

Collaborator


Could you explain the rationale for changing the comment here? Thank you.

Collaborator


If kv_managed is empty, this returns true. We need to guard this with an assertion here.

Collaborator


If kv_idx does not overlap, the function also returns true. We should guard this as well.

Collaborator


We should account for max_batch_size here.
